from torchvision.transforms import transforms
from PIL import Image
from utils.scale_resize import ScaleResize
import numpy as np
import cv2
import os
import albumentations as A
from albumentations import pytorch as AT
import pandas as pd

# Target side length, in pixels, shared by every pipeline in this module.
input_size = 224

# torchvision training augmentation; the last two steps convert the image to a
# tensor and normalize it per channel with the given mean / std.
train_transformer = transforms.Compose([
    # Random rotation of up to 180 degrees either way; expand=True grows the
    # canvas so the rotated image is not clipped.
    transforms.RandomRotation(degrees=180, expand=True),
    # An int size rescales the shorter edge to input_size (aspect ratio kept).
    transforms.Resize(size=input_size),
    transforms.RandomAffine(degrees=10),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Preview pipeline: only a resize followed by a center crop is active, and
# there is no tensor conversion or normalization step.
show_transforms = transforms.Compose([
    # Aspect ratio stays fixed; the shorter side is scaled to 256
    # (for input_size == 224).
    transforms.Resize(size=int(input_size * (256 / 224))),
    transforms.CenterCrop(size=input_size),
    # Alternatives kept disabled for quick experiments:
    # transforms.RandomRotation(degrees=180, expand=False),
    # transforms.Resize(224),
    # transforms.RandomResizedCrop(224),
    # transforms.RandomAffine(10),
    # transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomVerticalFlip(),
])

# Pipeline built around the project-local ScaleResize helper, followed by a
# random crop to input_size and a Gaussian blur.
# NOTE(review): ScaleResize is defined in utils.scale_resize; it presumably
# resizes to the given (square) size -- confirm its exact semantics there.
scale_transform = transforms.Compose([
    # Square target whose side is input_size * 256 / 224, truncated to int.
    ScaleResize((int(input_size * (256 / 224)),) * 2),
    transforms.RandomCrop(size=input_size),
    transforms.GaussianBlur(kernel_size=5),
    # Alternatives kept disabled for quick experiments:
    # ScaleResize((int(input_size), int(input_size)))
    # transforms.RandomHorizontalFlip(),
    # RandomRotate(15, 0.3),
    # RandomGaussianBlur(),
    # transforms.ToTensor(),
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# albumentations pipelines, keyed by dataset split name.
albu_transform = {
    # Training: random crop/flip/rotate/colour/blur augmentation, then
    # normalization and conversion to a CHW tensor.
    'train': A.Compose([
        # Random crop rescaled to input_size x input_size.
        # Disabled alternatives:
        # A.Resize(input_size, input_size),  # change scale
        # A.RandomCrop(input_size, input_size),
        A.RandomResizedCrop(input_size, input_size),  # stay scale
        A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.5),
        A.RandomRotate90(),
        A.RandomBrightnessContrast(),
        # Apply one of the listed blur variants.
        A.OneOf([
            A.GaussianBlur(),
            A.MedianBlur(blur_limit=3),
            # A.MotionBlur(),  # motion blur
        ]),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30),
        A.Normalize(),  # library defaults: ImageNet mean and std
        AT.ToTensorV2(p=1.0),  # also converts HWC -> CHW
    ]),
    # Validation: currently returns an un-normalized array (no Normalize or
    # ToTensorV2 step is active).
    # NOTE(review): RandomCrop and ShiftScaleRotate are random operations --
    # confirm that this is intended for a 'val' pipeline.
    'val': A.Compose([
        # A.LongestMaxSize((int(input_size * (256 / 224)))),
        # Reflection padding by default; pass border_mode=cv2.BORDER_CONSTANT
        # for zero padding:
        # A.PadIfNeeded(input_size, input_size),
        A.Resize(
            height=int(input_size * (256 / 224)),
            width=int(input_size * (256 / 224)),
        ),
        A.RandomCrop(height=input_size, width=input_size),
        A.ShiftScaleRotate(),

        # Disabled alternatives:
        # A.OneOf([
        #     A.CoarseDropout(),
        #     A.GridDropout(),
        # ]),
        # A.CenterCrop(input_size, input_size),
        # A.Normalize(),  # default imagenet std and mean
        # AT.ToTensorV2(p=1.0)  # include HWC -> CHW
    ]),
}

# Visual sanity check: pick one random image listed in test.csv, run it through
# the 'val' albumentations pipeline, and show the original and the transformed
# image side by side in OpenCV windows.
# NOTE(review): this runs at import time; consider moving it behind an
# `if __name__ == "__main__":` guard if other modules import the pipelines.
dataset_path = '../Dataset'
# Despite its name, this frame holds the rows of test.csv.
train_csv = pd.read_csv(os.path.join(dataset_path, 'test.csv'))
# The 'image' column supplies the file name, looked up inside <dataset_path>/test.
random_img_path = os.path.join(dataset_path, 'test', train_csv.sample(1)['image'].values[0])
print(random_img_path)
image = cv2.imread(random_img_path)
# cv2.imread signals failure by returning None instead of raising, so fail
# here with the offending path rather than on the attribute access below.
if image is None:
    raise FileNotFoundError(f'cv2.imread could not read image: {random_img_path}')
print(image.shape)
cv2.imshow('origin-img', image)
# Albumentations pipelines are called with keyword arguments and return a
# dict; the transformed image is stored under the 'image' key.
image_t = albu_transform['val'](image=image)['image']
print(image_t.shape)
cv2.imshow('trans-img', image_t)
# Block until a key is pressed so the windows stay visible.
cv2.waitKey()
